{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import os, sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import Ridge\n",
    "\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "tqdm.pandas()\n",
    "sys.path.append('..')\n",
    "import assign_loop_type\n",
    "from assign_loop_type import write_loop_assignments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Modified DegScore for Wayment-Steele, ... et al. arXiv 2022\n",
    "\n",
    "The original DegScore predicted all 0's for the first 10 nucleotides due to the padding applied at runtime. When predicting values summed over entire RNA sequences this was not apparent, but in nucleotide-by-nucleotide comparisons it resulted in erroneously high MCRMSE values. The code below repeats the process of creating the DegScore model, but with a `pad` of 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_input(df, window_size=1, pad=0, seq=True, struct=True, ensemble_size=0):\n",
    "    '''Create windowed per-nucleotide input features for a regression model predicting structure probing data.\n",
    "\n",
    "    Every position from `pad` up to the row's 'seq_scored' value becomes one sample: the\n",
    "    feature vectors of that position and of the `window_size` positions on either side,\n",
    "    flattened into a single row. Window positions outside the encoded region are all zeros.\n",
    "\n",
    "    Inputs:\n",
    "\n",
    "    dataframe (in EternaBench RDAT format); columns read per row are 'seq_scored', 'id',\n",
    "        'sequence' (if seq=True, or if struct=True with ensemble_size > 0) and\n",
    "        'predicted_loop_type' (if struct=True with ensemble_size=0)\n",
    "    window_size: size of window (in one direction). so window_size=1 is a total window size of 3\n",
    "    pad: number of nucleotides at start to not include\n",
    "    seq (bool): include one-hot sequence encoding (A, U, G, C)\n",
    "    struct (bool): include loop type structure encoding (H, E, I, M, B, S)\n",
    "    ensemble_size: if > 0, each loop type feature is the fraction of the structures returned by\n",
    "        get_ensemble(sequence, n=ensemble_size) that carry that loop type at the position;\n",
    "        if 0, the single 'predicted_loop_type' string is used\n",
    "\n",
    "    Outputs:\n",
    "    Input array (n_samples x n_features): array of windowed input features\n",
    "    feature_names (list): feature names, formatted as feature_offset (e.g. 'A_-1')\n",
    "    labels (list): one 'id_position' string per sample, in the same order as the input array\n",
    "\n",
    "    '''\n",
    "    BASES = ['A','U','G','C']\n",
    "    STRUCTS = ['H','E','I','M','B','S']\n",
    "\n",
    "    feature_kernel = []\n",
    "    if seq:\n",
    "        feature_kernel.extend(BASES)\n",
    "    if struct:\n",
    "        feature_kernel.extend(STRUCTS)\n",
    "    n_kernel = len(feature_kernel)\n",
    "    n_window = 2*window_size + 1\n",
    "\n",
    "    feature_names = ['%s_%d' % (k, val) for val in range(-1*window_size, window_size+1) for k in feature_kernel]\n",
    "\n",
    "    # zero rows stacked on both sides so windows at the edges keep full width (same for every row)\n",
    "    edge_zeros = np.zeros([window_size, n_kernel])\n",
    "\n",
    "    inpts = []\n",
    "    labels = []\n",
    "\n",
    "    for _, row in tqdm(df.iterrows(), desc='Encoding inputs', total=len(df)):\n",
    "        n_scored = row['seq_scored']\n",
    "        arr = np.zeros([n_scored, n_kernel])\n",
    "\n",
    "        # Structure data is only touched when requested, so a sequence-only encoding\n",
    "        # works on dataframes that have no 'predicted_loop_type' column.\n",
    "        ensemble = None\n",
    "        if struct:\n",
    "            if ensemble_size > 0: # stochastically sample ensemble\n",
    "                # NOTE(review): get_ensemble is not defined or imported in this notebook;\n",
    "                # this branch raises NameError unless it is provided elsewhere.\n",
    "                ensemble = get_ensemble(row['sequence'], n=ensemble_size)\n",
    "            else: # use the single predicted structure\n",
    "                ensemble = np.array([list(row['predicted_loop_type'])])\n",
    "\n",
    "        for index in range(pad, n_scored):\n",
    "            ctr = 0  # column cursor into feature_kernel\n",
    "\n",
    "            # encode sequence (one-hot)\n",
    "            if seq:\n",
    "                base = row['sequence'][index]\n",
    "                for char in BASES:\n",
    "                    if base == char:\n",
    "                        arr[index, ctr] += 1\n",
    "                    ctr += 1\n",
    "\n",
    "            # encode structure (fraction of ensemble members with each loop type)\n",
    "            if struct:\n",
    "                loop_assignments = ''.join(ensemble[:, index])\n",
    "                for char in STRUCTS:\n",
    "                    arr[index, ctr] += loop_assignments.count(char) / len(loop_assignments)\n",
    "                    ctr += 1\n",
    "\n",
    "        # add zero padding to the side\n",
    "        padded_arr = np.vstack([edge_zeros, arr[pad:], edge_zeros])\n",
    "\n",
    "        for index in range(pad, n_scored):\n",
    "            # in padded_arr the window centred on `index` starts at row index - pad\n",
    "            start = index - pad\n",
    "            inpts.append(padded_arr[start:start + n_window].flatten())\n",
    "            labels.append('%s_%d' % (row['id'], index))\n",
    "\n",
    "    if not inpts:\n",
    "        # keep the result 2-D even when nothing was encoded\n",
    "        return np.zeros([0, n_window*n_kernel]), feature_names, labels\n",
    "    return np.array(inpts), feature_names, labels\n",
    "\n",
    "def encode_output(df, data_type='reactivity', pad=0):\n",
    "    '''Creat input/output for regression model for predicting structure probing data.\n",
    "    Inputs:\n",
    "    \n",
    "    dataframe (in EternaBench RDAT format)\n",
    "    data_type: column name for degradation\n",
    "    window_size: size of window (in one direction). so window_size=1 is a total window size of 3\n",
    "    pad: number of nucleotides at start to not include\n",
    "    \n",
    "    Outputs:\n",
    "    output array (n_samples): array of reactivity values\n",
    "    \n",
    "    '''\n",
    "    #MAX_LEN = 68\n",
    "    \n",
    "    outpts = []\n",
    "    labels = []\n",
    "    # output identity should be in form id_00073f8be_0\n",
    "\n",
    "    for i, row in df.iterrows():\n",
    "        MAX_LEN = row['seq_scored']\n",
    "        \n",
    "        for index in range(pad,MAX_LEN):\n",
    "            outpts.append(row[data_type][index])\n",
    "            labels.append('%s_%d' % (row['id'], index))\n",
    "            \n",
    "            \n",
    "    return outpts, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the training rows whose SN_filter flag equals 1.\n",
    "kaggle_train = pd.read_json('train.json', lines=True).loc[lambda frame: frame['SN_filter'] == 1]\n",
    "\n",
    "kaggle_test = pd.read_json('test.json', lines=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['index', 'id', 'sequence', 'structure', 'predicted_loop_type',\n",
       "       'signal_to_noise', 'SN_filter', 'seq_length', 'seq_scored',\n",
       "       'reactivity_error', 'deg_error_Mg_pH10', 'deg_error_pH10',\n",
       "       'deg_error_Mg_50C', 'deg_error_50C', 'reactivity', 'deg_Mg_pH10',\n",
       "       'deg_pH10', 'deg_Mg_50C', 'deg_50C'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Columns available in the training table.\n",
    "kaggle_train.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encode data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Max. expected accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92b1b650530a453b9c0c3b3071a40f35",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Encoding inputs', max=1589, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d10e9999f464016bbc8befb71d81004",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Encoding inputs', max=3634, style=ProgressStyle(description_w…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# One shared set of encoder arguments so train and test are encoded identically.\n",
    "encode_kwargs = dict(window_size=12)\n",
    "\n",
    "inputs_train, feature_names, _ = encode_input(kaggle_train, **encode_kwargs)\n",
    "inputs_test, _, test_labels = encode_input(kaggle_test, **encode_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visualize encoding for an example nucleotide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'window position')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARsAAACiCAYAAABxn2koAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQdklEQVR4nO3de7RcZXnH8e/PpJhA5BISaYCEIA13KeUctALK1RYpLKC1QmDJpS1UxVJbU0vRIixkqV2IWg2yImLCHYpiKaSsUDUEbbicA4EAAQpCuCSFJFwj9/D0j/0OnTWZk+w5Z/aemZ3fZ61ZZ99m72efOedZ7/vu931HEYGZWdHe0+kAzGzD4GRjZqVwsjGzUjjZmFkpnGzMrBRONmZWCicb6wmSzpZ0eVqeImm1pFGdjsvyc7LpQZKekPSmpAkN2xdJCklT0/rsdNzqutcxded4rWHf98u/m9ZFxJMRMS4i1nQ6FsvPyaZ3PQ5Mr61I+iAwtslx/5L+MWuva+r2HdGw7/NFB20bLieb3nUZcELd+onApUVcSNJ7JJ0h6TFJqyRdK2l82jc1laZOlPSkpJWSvlz33lGSzkzvfUXSoKTJad8+ku6S9FL6uU/d+7aXdGt6zy3AhLp9tWuOTuvzJZ0r6dfp+Hn1pT5JJ0hammL/51SqO6SI35UNzcmmd90ObCppl9R2cQxweUHXOh04Ctgf2Bp4AZjZcMx+wE7AwcBZknZJ2/+erAR2GLAp8BfAqylZ3QT8K7AlcAFwk6Qt0/uuBAbJksy5ZMl0XY4DTgbeD2wEzACQtCtwIXA8MAnYDNimpbu39ogIv3rsBTwBHAJ8Bfg6cChwCzAaCGBqOm428DrwYnqtbDjH6rp9LwKnDHG9JcDBdeuTgLfS9aama25bt/9O4Ni0/DBwZJNzfhq4s2HbQuAkYArwNrBJ3b4rgcvTcu2ao9P6fOArdcd+Drg5LZ8FXFW3b2PgTeCQTn+OG9pr9DBzlHWHy4AFwPYMXYU6PyK+MsS+oyLiv3JcZzvgeknv1G1bA2xVt/6/dcuvAuPS8mTgsSbn3BpY2rBtKVmpY2vghYj4bcO+yeuIcajrbw08VdsREa9KWrWO81hBXI3qYRGxlKyh+DDgpwVe6ingExGxed1rTEQ8k/O9OzTZvowsidWbAjwDLAe2kLRJw77hWA5sW1uRNJas2mYlc7LpfX8JHNRQCmi3i4DzJG0HIGmipCNzvvdi4FxJ05TZI7XLzAV2lHScpNHpkfyuwI0piQ4A50jaSNJ+wBHDjP064IjUGL0RcA6gYZ7LRsDJpsdFxGMRMTDMt/9HQz+b64c47rvADcA8Sa+QNU5/OOc1LgCuBeYBLwM/AsZGxCrgcOCLwCrgS8DhEbEyve+4dI3nga8yzCdtEfEA8DfA1WSlnFeA54A3hnM+Gz6lRjOzDYKkcWSN4dMi4vFOx7MhccnGKk/SEZI2Tm1A5wOLyZ7GWYmcbGxDcCRZg/QyYBrZY3kX6UvmapSZlcIlGzMrRaGd+iTlKjb19fUVGcY6DQ4O5jqukzG2W7ffc974WtHtn1+777mV+233tSOiadeCQqtReZNNJ6tyUr4uF1Wqbnb7PeeNrxXd/vm1+55bud8Crt30hC1VoyQdnUbb7tyesMxsQ9Fqm8104FfAsQXEYmYVljvZpM5Q+5J1j3eyMbOWtFKyOYps2P4jwPOS9mp2kKRTJQ1IGm4XejOroNwNxJJuAr4TEbdIOh2YHBH/sJ73uIG4C3X7PbuBeOS6sYE4V7JJo3SfJhvAFsCo9HO7dfXEdLLpTt1+z042I9eNySZvNeqTwKURsV1ETI2IyWTzqOzXrgDNrNryJpvpQOP0Az8hmwbAzGy9Cu3U19/fHwMD5bcTt1IsbKHNarjhjEgR8XV79cjV6pHr1N9Df38/AwMDI+/UZ2Y2XLmSTfqenvsbtp0taUYxYZlZ1bhk
Y2alcLIxs1I42ZhZKfImm6Gaq9faXj9cYcWKFcOPzMwqJW+yWQVs0bBtPLCy8cCImBUR/RHRP3HixJHGZ2YVkSvZRMRqYLmkgwHSl8IfSjbdhJnZerUyLegJwExJ30rr50REs+9wNjNbS+5kExEPAgcWGMt6dbJ3Z7f3GC0ivk4ODuyUTvUob/fvpht/134aZWalcLIxs1K0Mi3oVpKulPQbSYOSFko6usjgzKw68o6NEvAzYEFEfCAi+sjmId62yODMrDryNhAfBLwZERfVNkTEUuB7hURlZpWTtxq1G3B3ngPdg9jMmhlWA7GkmZLulXRX4z73IDazZvImmweAd7+6JSJOAw4GnE3MLJe8yeYXwBhJn63btnEB8ZhZReUdGxVkX1K3v6THJd0JzAH+scjgzKw6WhmusJyCvna3KpNMF8FDNNqj3b/HKv1uyuIexGZWilZ6EP+upKslPSbpQUlzJe1YZHBmVh2t9CC+HpgfETtExK7AmcBWRQZnZtWRt83mQOCthh7Ei4oJycyqKG81andgMM+B7kFsZs20vYHYPYjNrJlWehD3FRmImVVbKz2I3yvplNoGSXtL2r+YsMysalrpQXw08PH06PsB4GxgWYGxmVmFtNKDeBnwqSKCcG/MofXC76bdk38XoRd+j1XnHsRmVoq8nfpWN6yfJOn7xYRkZlXkko2ZlcLJxsxKkbeBeKyk+uEJ44Ebmh0o6VTgVIApU6aMLDozq4y8JZvXImLP2gs4a6gD3YPYzJpxNcrMSuFkY2alcLIxs1LkaiCOiHEN67OB2QXEYz3IvXMtD5dszKwUucdGAUhaAyyu23R1RHyjvSGZWRW1lGxIj8ALicTMKs3VKDMrRavJZqykRXWvYxoP8BzEZtZM26tRETELmAXQ39/vxxRmBrgaZWYlcbIxs1K0Wo1qHP19c0Sc0c6AzKyaWko2ETGqleMHBwfbOj9tET1V2z1/rnvTWjcoYl7okf5tuxplZqUYdrJpnJfYzGxdXLIxs1I42ZhZKVp9GrVe9XMQm5nVtD3Z1PcgluRHM2YGuBplZiVxsjGzUjjZmFkpht1m0zgvcTN9fX0MDAwM9xKlcI9fq6Ju/Lt2ycbMSrHeZCMpJF1Wtz5a0gpJNxYbmplVSZ6SzW+B3SWNTesfB54pLiQzq6K81aj/BP4kLU8HriomHDOrqrzJ5mrgWEljgD2AO4Y60HMQm1kzuZJNRNwHTCUr1cxdz7GzIqI/IvonTpw48gjNrBJaefR9A3A+cACwZSHRmFlltZJsLgFeiojFkg4oKB4zq6jcySYinga+W2AsZlZh6002zXoKR8R8YH4B8ZhZRbkHsZmVIneykbQmfeXuvZLulrRPkYGZWbW00kD87lfvSvpj4OvA/oVEZWaVM9xq1KbAC+0MxMyqrZWSTe3bMMcAk4CDignJzKqolZLNaxGxZ0TsDBwKXKomX7vn4Qpm1sywqlERsRCYAKw1HsHDFcysmWElG0k7A6OAVe0Nx8yqajhtNgACToyINQXEZGYV1MpwhVFFBmJm1db2L6krUpP26BFr98TQRcTY7bpxcu1GG9rnUsRnMtLfoYcrmFkpWko2kr4s6QFJ96WhCx8uKjAzq5bc1ShJHwEOB/aKiDckTQA2KiwyM6uUVtpsJgErI+INgIhYWUxIZlZFrVSj5gGTJT0i6UJJTQdhugexmTWTO9lExGqgDzgVWAFcI+mkJse5B7GZraWlR9+pE998YL6kxcCJwOz2h2VmVdPK5Fk7SZpWt2lPYGn7QzKzKmqlZDMO+J6kzYG3gUfJqlRmZuulInt/SlrB2qWfCUBVnmRV5V6qch/ge+m07SKiaWNtocmm6QWlgYjoL/WiBanKvVTlPsD30s08XMHMSuFkY2al6ESymdWBaxalKvdSlfsA30vXKr3Nxsw2TK5GmVkpnGzMrBSlJRtJh0p6WNKjks4o67pFkPSEpMVpTp+BTsfTCkmXSHpO0v1128ZLukXS
/6SfW3QyxryGuJezJT2TPptFkg7rZIx5SJos6ZeSlqT5ov42be/Jz2UopSQbSaOAmcAngF2B6ZJ2LePaBTowfY9Wr/WDmE32vV/1zgB+HhHTgJ+n9V4wm7XvBeDb6bPZMyLmlhzTcLwNfDEidgH+EDgt/X/06ufSVFklmw8Bj0bEbyLiTeBq4MiSrm11ImIB8HzD5iOBOWl5DnBUqUEN0xD30nMiYnlE3J2WXwGWANvQo5/LUMpKNtsAT9WtP5229aoA5kkalFSF8WFbRcRyyP7wgfd3OJ6R+nyauvaSXqt6SJoK/AFwBxX7XMpKNs2mZe/lZ+77RsReZNXC0yR9rNMB2bt+AOxANivBcuBbnQ0nP0njgJ8AX4iIlzsdT7uVlWyeBibXrW8LLCvp2m0XEcvSz+eA68mqib3sWUmTANLP5zocz7BFxLMRsSYi3gF+SI98NpJ+hyzRXBERP02bK/O5QHnJ5i5gmqTtJW0EHAvcUNK120rSJpLeV1sG/gi4f93v6no3kE2ERvr57x2MZURq/5zJ0fTAZ6PsC5l+BCyJiAvqdlXmc4ESexCnR5DfIfuO8Esi4rxSLtxmkj5AVpqBbD6gK3vpXiRdBRxANn3Bs8BXgZ8B1wJTgCeBP4+Irm94HeJeDiCrQgXwBPDXtXaPbiVpP+A2YDHwTtp8Jlm7Tc99LkPxcAUzK4V7EJtZKZxszKwUTjZmVgonGzMrhZONmZXCyaaiJM1NX7uT9/ip9aOnu4Wkz0g6IS2fJGnrun0XV2BA7wbDj74NeHdMzo0RsXuHQxmSpPnAjIjoqWk9LOOSTQ+S9CVJp6flb0v6RVo+WNLlafkJSRNSiWWJpB+muVLmSRqbjumTdK+khcBpdecfI+nHac6eeyQdmLbPlbRHWr5H0llp+VxJf9UQ41RJD0makwZFXidp47o470nnv0TSe9P2b0h6MB1/ftp2tqQZkj4J9ANXpHlqxkqaL6k/HTc9ne9+Sd+si2O1pPPSfd4uaasCPhLLwcmmNy0APpqW+4FxaWxNrSdqo2nAzIjYDXgR+LO0/cfA6RHxkYbjTwOIiA8C04E5ksbUritpU7I5WPZNxw913Z2AWRGxB/Ay8Ll0ntnAMen8o4HPShpPNrxgt3T81+pPFBHXAQPA8Wmemtdq+1LV6pvAQWS9h/eWVJuOYRPg9oj4/RT/KU3itBI42fSmQaAvjdF6A1hIlnQ+SvN/+scjYlHde6dK2gzYPCJuTdsvqzt+v9p6RDxE9q2mO6Zzfyztv4ksyW0MTI2Ih5tc96mI+HVavjy9b6cUzyNp+5x0zpeB14GLJf0p8GreXwawNzA/IlZExNvAFemcAG8CN9bfewvntTZysulBEfEW2bifk4H/JksCB5JNrbCkyVveqFteQ1aaEENP89FsShDIBtTWktoC4B6yksLgUKE2WW967pQkPkQ28vko4OYhztlKvABvxf83TNbu3TrAyaZ3LQBmpJ+3AZ8BFkXOFv+IeBF4KQ0CBDi+4dzHA0jakWwg4MNplsWngE8Bt6frzqB5aQpgiqRaFW068CvgIbKS1e+l7Z8Gbk1zuWyWpvH8All1qNErwPuabL8D2D+1UY1K17q1yXHWQU42ves2YBKwMCKeJauCDPVPP5STgZmpgfi1uu0XAqMkLQauAU6KiFrp6Dbg2Yh4NS1vu47rLgFOlHQfMB74QUS8nq77b+n87wAXkSWRG9OxtwJ/1+R8s4GLag3EtY1pVPc/Ab8E7gXujoieno6hivzo2wrRC4/SrVwu2ZhZKVyyMbNSuGRjZqVwsjGzUjjZmFkpnGzMrBRONmZWiv8DOe86sM5fQXQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 720x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Heatmap rows, in the order encode_input lays out the features of each window position.\n",
    "feature_rows = ['A','U','G','C','H','E','I','M','B','S']\n",
    "\n",
    "figure(figsize=(10,4))\n",
    "subplot(1,2,1)\n",
    "title('MFE encoding')\n",
    "# -1 lets numpy infer the number of window positions, so this keeps working if window_size changes\n",
    "imshow(inputs_train[33].reshape(-1, len(feature_rows)).T, cmap='gist_heat_r')\n",
    "yticks(range(len(feature_rows)), feature_rows)\n",
    "xlabel('window position')\n",
    "ylabel('feature');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### To set up kaggle submission format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_submission = pd.read_csv('sample_submission.csv.zip')\n",
    "mask = sample_submission['id_seqpos'].isin(test_labels)\n",
    "\n",
    "# Predictions are written into the masked rows by position, so the masked rows must list\n",
    "# exactly the same labels, in exactly the same order, as test_labels.\n",
    "assert (sample_submission.loc[mask, 'id_seqpos'].tolist() == list(test_labels)), 'sample_submission rows are not aligned with test_labels'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model based on single MFE structure (primary DegScore model used)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting reactivity ...\n",
      "Fitting deg_Mg_pH10 ...\n",
      "Fitting deg_pH10 ...\n",
      "Fitting deg_Mg_50C ...\n",
      "Fitting deg_50C ...\n"
     ]
    }
   ],
   "source": [
    "# Fit one ridge regression per measured quantity and fill in its submission column.\n",
    "for output_type in ('reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C'):\n",
    "    outputs_train, outputs_labels = encode_output(kaggle_train, data_type=output_type)\n",
    "\n",
    "    print('Fitting %s ...' % output_type)\n",
    "    reg = Ridge(alpha=0.15, fit_intercept=False).fit(inputs_train, outputs_train)\n",
    "\n",
    "    # only the rows selected by mask receive predictions\n",
    "    test_prediction = reg.predict(inputs_test)\n",
    "    sample_submission.loc[mask, output_type] = test_prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the submission table as CSV, leaving out the dataframe index.\n",
    "sample_submission.to_csv('mar82022_v1.csv',index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
